{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import typing as T
import unittest

import numpy as np

import sym


class CamPackageTest(unittest.TestCase):
    """
    Tests for Python runtime camera types. Mostly checking basic generation logic
    since the math is tested comprehensively in symbolic form.
    """

    _DISTORTION_COEFF_VALS = {{ _DISTORTION_COEFF_VALS }}

    @staticmethod
    def cam_cal_from_points(
        cam_cls: T.Type, focal_length: T.Sequence[float], principal_point: T.Sequence[float]
    ) -> T.Any:
        """
        Construct an instance of cam_cls from a focal length and a principal point.

        Any keyword arguments stored under the class's name in _DISTORTION_COEFF_VALS are
        forwarded to the constructor as well; classes without an entry get no extra arguments.
        """
        extra_kwargs = CamPackageTest._DISTORTION_COEFF_VALS.get(cam_cls.__name__, {})
        return cam_cls(focal_length=focal_length, principal_point=principal_point, **extra_kwargs)

    {% for cls in all_types %}
    def test_getters_{{ cls.__name__ }}(self) -> None:
        # The values handed to the constructor should come back unchanged from the getters.
        expected_focal_length = [1.0, 2.0]
        expected_principal_point = [3.0, 4.0]
        calibration = self.cam_cal_from_points(
            sym.{{ cls.__name__ }},
            focal_length=expected_focal_length,
            principal_point=expected_principal_point,
        )

        np.testing.assert_array_equal(
            np.array(expected_focal_length), calibration.focal_length()
        )
        np.testing.assert_array_equal(
            np.array(expected_principal_point), calibration.principal_point()
        )

        with self.subTest("Getters are compatible with the constructor"):
            # Feed the getter outputs straight back into the constructor; the rebuilt object
            # must hold the same data, and every stored entry must be a plain float.
            rebuilt = self.cam_cal_from_points(
                sym.{{ cls.__name__ }},
                focal_length=calibration.focal_length(),
                principal_point=calibration.principal_point(),
            )
            self.assertEqual(calibration.data, rebuilt.data)
            for value in rebuilt.data:
                self.assertIsInstance(value, float)

    def test_storage_ops_{{ cls.__name__ }}(self) -> None:
        # Checks the storage dimension and a to_storage / from_storage round trip.
        original = self.cam_cal_from_points(
            sym.{{ cls.__name__ }}, focal_length=[1.0, 2.0], principal_point=[3.0, 4.0]
        )
        flat = original.to_storage()

        # The expected dimension is rendered into this file by the template.
        expected_dim = {{ cls.storage_dim() }}
        self.assertEqual(original.storage_dim(), expected_dim)
        self.assertEqual(len(flat), expected_dim)

        # Rebuilding from the same storage must give an equal object.
        round_tripped = sym.{{ cls.__name__ }}.from_storage(flat)
        self.assertEqual(original, round_tripped)

        # Offsetting every storage entry by one must give an unequal object.
        shifted = sym.{{ cls.__name__ }}.from_storage([entry + 1 for entry in flat])
        self.assertNotEqual(original, shifted)

    {% set symbolic_cam_cal = cam_cal_from_points(
        cam_cls=cls,
        focal_length=[1.0, 2.0],
        principal_point=[3.0, 4.0],
    )%}
    {% set epsilon = 1e-8 %}

    def test_lie_group_ops_{{ cls.__name__ }}(self) -> None:
        # Compares the runtime class's tangent_dim, to_tangent, from_tangent, retract and
        # local_coordinates against expected values rendered by the template.
        # NOTE(brad): The magic numbers come from the jinja template, and are the outputs
        # of the symbolic class's methods.
        cam_cal = sym.{{ cls.__name__ }}.from_storage({{ symbolic_cam_cal.to_storage() | map("float") | list }})

        # This tangent vector is reused below as the input to from_tangent and retract.
        tangent = cam_cal.to_tangent(epsilon={{ epsilon }})

        {% set tangent_dim = ops.LieGroupOps.tangent_dim(cls) %}
        # Test tangent_dim is correct
        self.assertEqual(cam_cal.tangent_dim(), {{ tangent_dim }})

        # Test to_tangent is correct
        {% set symbolic_tangent = ops.LieGroupOps.to_tangent(symbolic_cam_cal, epsilon=epsilon) %}
        np.testing.assert_allclose(
            tangent,
            np.array({{ symbolic_tangent | map("float") | list }})
        )

        # Test from_tangent is correct
        np.testing.assert_allclose(
            sym.{{ cls.__name__ }}.from_tangent(vec=tangent, epsilon={{ epsilon }}).to_storage(),
            {{ ops.LieGroupOps.from_tangent(cls, vec=symbolic_tangent, epsilon=epsilon).to_storage() | map("float") | list }}
        )

        {% set second_symbolic_cam_cal = cam_cal_from_points(
            cam_cls=cls,
            focal_length=[3.3, 5.5],
            principal_point=[2.4, 3.8],
        )%}
        # A second calibration with a different focal length and principal point serves as
        # the base element for the retract and local_coordinates checks.
        second_cam_cal = sym.{{ cls.__name__ }}.from_storage({{ second_symbolic_cam_cal.to_storage() | map("float") | list }})

        # Test retract is correct
        np.testing.assert_allclose(
            second_cam_cal.retract(vec=tangent, epsilon={{ epsilon }}).to_storage(),
            {{ ops.LieGroupOps.retract(second_symbolic_cam_cal, vec=symbolic_tangent, epsilon=epsilon).to_storage() | map("float") | list }}
        )

        # Test local_coordinates is correct
        np.testing.assert_allclose(
            second_cam_cal.local_coordinates(cam_cal, epsilon={{ epsilon }}),
            np.array({{ ops.LieGroupOps.local_coordinates(second_symbolic_cam_cal, symbolic_cam_cal, epsilon=epsilon) | map("float") | list }})
        )

    def test_pixel_from_camera_point_{{ cls.__name__ }}(self) -> None:
        # Projects a fixed 3D point with the runtime class and compares the pixel and validity
        # flag against expected values rendered by the template.
        # NOTE(brad): The magic numbers come from the jinja template, and are the outputs
        # of the symbolic class's methods.
        {% set point = Matrix([0.6, 0.8, 0.2]) %}
        {% set pre_eval_pixel, pre_eval_is_valid = symbolic_cam_cal.pixel_from_camera_point(point, epsilon=epsilon)%}
        {% set pixel = pre_eval_pixel.evalf() %}
        {% set is_valid = pre_eval_is_valid.evalf() %}

        cam_cal = sym.{{ cls.__name__ }}.from_storage({{ symbolic_cam_cal.to_storage() | map("float") | list }})
        point = np.array({{ point.to_storage() | map("float") | list }})

        pixel, is_valid = cam_cal.pixel_from_camera_point(point=point, epsilon={{ epsilon }})

        # The expected pixel entries are rounded to 14 decimal places when rendered.
        np.testing.assert_allclose(pixel, np.array({{ pixel.to_storage() | map("float") | map("round", 14) | list }}))
        self.assertEqual(is_valid, {{ is_valid | float }})

        # The jacobian variant must return the same pixel and validity flag, plus two jacobians
        # whose shapes are checked: 2 x (number of parameters) and 2 x 3.
        j_pixel, j_is_valid, pixel_D_cal, pixel_D_point = cam_cal.pixel_from_camera_point_with_jacobians(point=point, epsilon={{ epsilon }})

        np.testing.assert_allclose(j_pixel, np.array({{ pixel.to_storage() | map("float") | map("round", 14) | list }}))
        self.assertEqual(j_is_valid, {{ is_valid | float }})
        self.assertEqual(pixel_D_cal.shape, (2, {{ symbolic_cam_cal.parameters() | length }}))
        self.assertEqual(pixel_D_point.shape, (2, 3))

    {% if cls.has_camera_ray_from_pixel() %}
    def test_camera_ray_from_pixel_{{ cls.__name__ }}(self) -> None:
        # Back-projects a fixed pixel with the runtime class and compares the ray and validity
        # flag against expected values rendered by the template.
        # NOTE(brad): The magic numbers come from the jinja template, and are the outputs
        # of the symbolic class's methods.
        {% set pixel = Matrix([0.6, 0.8]) %}
        {% set pre_eval_ray, pre_eval_is_valid = symbolic_cam_cal.camera_ray_from_pixel(pixel, epsilon=epsilon)%}
        {% set ray = pre_eval_ray.evalf() %}
        {% set is_valid = pre_eval_is_valid.evalf() %}

        cam_cal = sym.{{ cls.__name__ }}.from_storage({{ symbolic_cam_cal.to_storage() | map("float") | list }})
        pixel = np.array({{ pixel.to_storage() | map("float") | list }})

        ray, is_valid = cam_cal.camera_ray_from_pixel(pixel=pixel, epsilon={{ epsilon }})

        np.testing.assert_allclose(ray, np.array({{ ray.to_storage() | map("float") | list }}))
        self.assertEqual(is_valid, {{ is_valid | float }})

        # The jacobian variant must return the same ray and validity flag, plus two jacobians
        # whose shapes are checked: 3 x (number of parameters) and 3 x 2.
        j_ray, j_is_valid, point_D_cal, point_D_pixel = cam_cal.camera_ray_from_pixel_with_jacobians(pixel=pixel, epsilon={{ epsilon }})

        np.testing.assert_allclose(j_ray, np.array({{ ray.to_storage() | map("float") | list }}))
        self.assertEqual(j_is_valid, {{ is_valid | float }})
        self.assertEqual(point_D_cal.shape, (3, {{ symbolic_cam_cal.parameters() | length }}))
        self.assertEqual(point_D_pixel.shape, (3, 2))
    {% endif %}


    {% endfor %}

# Run the tests through unittest's command-line entry point when executed as a script.
if __name__ == "__main__":
    unittest.main()
